{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2: imported modules are\n",
    "# reloaded before each cell runs, so edits to local packages are picked up\n",
    "# without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import ase\n",
    "import ase.io\n",
    "import numpy as np\n",
    "\n",
    "from mcmc.system import SurfaceSystem\n",
    "from mcmc.utils.misc import get_atoms_batch\n",
    "\n",
    "# Print NumPy arrays with 3 decimal places and without scientific notation\n",
    "np.set_printoptions(precision=3, suppress=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the test slabs and parameters. We test the dominant surface terminations at different $\\mu_{Sr}$ values: the DL-TiO2 termination at low $\\mu_{Sr}$, the SL-TiO2 termination at intermediate $\\mu_{Sr}$, and the SL-SrO termination at high $\\mu_{Sr}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove every handler currently attached to the root logger\n",
    "# (i.e. the default one installed by the Jupyter notebook)\n",
    "default_logger = logging.getLogger()\n",
    "default_logger.handlers.clear()\n",
    "\n",
    "# Path to the offset data stored next to the NFF models in the tutorials folder\n",
    "offset_data_path = os.path.join(\n",
    "    \"../tutorials\",\n",
    "    \"data/SrTiO3_001/nff\",\n",
    "    \"offset_data.json\",\n",
    ")\n",
    "\n",
    "# Prepared reference slabs (CIF files), resolved relative to this notebook\n",
    "ref_slab_files = [\n",
    "    \"data/SrTiO3_001/O44Sr12Ti16.cif\",\n",
    "    \"data/SrTiO3_001/O36Sr12Ti12.cif\",\n",
    "    \"data/SrTiO3_001/O40Sr16Ti12.cif\",\n",
    "]\n",
    "\n",
    "# `ase.io` is a submodule that a bare `import ase` is not guaranteed to load;\n",
    "# it is imported explicitly in the imports cell so that this call does not\n",
    "# depend on another package having imported it as a side effect.\n",
    "ref_slabs = [ase.io.read(slab_file) for slab_file in ref_slab_files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folder for this run, named after the surface and created if missing\n",
    "surface_name = \"SrTiO3_001\"\n",
    "run_folder = Path(surface_name)\n",
    "run_folder.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Read the offset data; print a hint before re-raising if the file is absent\n",
    "try:\n",
    "    with open(offset_data_path, \"r\") as offset_file:\n",
    "        offset_data = json.load(offset_file)\n",
    "except FileNotFoundError as err:\n",
    "    print(\"Offset data file not found. Please check you have downloaded the data.\")\n",
    "    raise err\n",
    "\n",
    "# Settings later passed to the calculator via `set(**calc_settings)`\n",
    "calc_settings = dict(\n",
    "    calc_name=\"NFF\",\n",
    "    optimizer=\"BFGS\",\n",
    "    chem_pots=dict(Sr=-2, Ti=0, O=0),\n",
    "    relax_atoms=True,\n",
    "    relax_steps=20,\n",
    "    offset=True,\n",
    "    offset_data=offset_data,\n",
    ")\n",
    "\n",
    "# Settings later passed to `SurfaceSystem` as `system_settings`\n",
    "system_settings = dict(\n",
    "    surface_name=surface_name,\n",
    "    surface_depth=1,\n",
    "    cutoff=5.0,\n",
    "    near_reduce=0.01,\n",
    "    planar_distance=1.5,\n",
    "    no_obtuse_hollow=True,\n",
    "    ads_site_type=\"all\",\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the NFF calculator. Here, we use the same neural network weights as in our Zenodo dataset (https://zenodo.org/record/7927039). The ensemble requires an `offset_data.json` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/dux/NeuralForceField/models\n",
      "offset data: {'bulk_energies': {'O': -0.17747231201, 'Sr': -0.06043637668, 'SrTiO3': -1.470008697358702}, 'stoidict': {'Sr': 0.49995161381315867, 'Ti': -0.0637500349111578, 'O': -0.31241304903276834, 'offset': -11.324476454433157}, 'stoics': {'Sr': 1, 'Ti': 1, 'O': 3}, 'ref_formula': 'SrTiO3', 'ref_element': 'Ti'} is set from parameters\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'calc_name': 'NFF',\n",
       " 'optimizer': 'BFGS',\n",
       " 'chem_pots': {'Sr': -2, 'Ti': 0, 'O': 0},\n",
       " 'relax_atoms': True,\n",
       " 'relax_steps': 20,\n",
       " 'offset': True,\n",
       " 'offset_data': {'bulk_energies': {'O': -0.17747231201,\n",
       "   'Sr': -0.06043637668,\n",
       "   'SrTiO3': -1.470008697358702},\n",
       "  'stoidict': {'Sr': 0.49995161381315867,\n",
       "   'Ti': -0.0637500349111578,\n",
       "   'O': -0.31241304903276834,\n",
       "   'offset': -11.324476454433157},\n",
       "  'stoics': {'Sr': 1, 'Ti': 1, 'O': 3},\n",
       "  'ref_formula': 'SrTiO3',\n",
       "  'ref_element': 'Ti'}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from nff.io.ase_calcs import NeuralFF\n",
    "from nff.utils.cuda import cuda_devices_sorted_by_free_mem\n",
    "\n",
    "from mcmc.calculators import EnsembleNFFSurface\n",
    "\n",
    "# Use a GPU when one is available: take the last entry of the device list\n",
    "# sorted by free memory (presumably the one with the most free memory --\n",
    "# confirm the sort order in NFF); otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = f\"cuda:{cuda_devices_sorted_by_free_mem()[-1]}\"\n",
    "else:\n",
    "    DEVICE = \"cpu\"\n",
    "\n",
    "# requires an ensemble of models in this path and an `offset_data.json` file\n",
    "nnids = [\"model01\", \"model02\", \"model03\"]\n",
    "model_dirs = [\n",
    "    os.path.join(\"../tutorials\", \"data/SrTiO3_001/nff\", str(nnid), \"best_model\")\n",
    "    for nnid in nnids\n",
    "]\n",
    "\n",
    "# Load each ensemble member and keep only its underlying model\n",
    "models = [NeuralFF.from_file(model_dir, device=DEVICE).model for model_dir in model_dirs]\n",
    "\n",
    "# Wrap the models in the ensemble surface calculator and apply the settings;\n",
    "# the value returned by `set` is displayed as the cell output.\n",
    "nff_surf_calc = EnsembleNFFSurface(models, device=DEVICE)\n",
    "nff_surf_calc.set(**calc_settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the reference slabs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Step     Time          Energy         fmax\n",
      "BFGS:    0 16:30:47     -570.127991        0.7374\n",
      "BFGS:    1 16:30:47     -570.142517        0.6409\n",
      "BFGS:    2 16:30:47     -570.188354        0.1144\n",
      "BFGS:    3 16:30:47     -570.188721        0.1004\n",
      "BFGS:    4 16:30:47     -570.189758        0.0113\n",
      "BFGS:    5 16:30:48     -570.189758        0.0097\n",
      "      Step     Time          Energy         fmax\n",
      "BFGS:    0 16:30:48     -467.525543        0.1416\n",
      "BFGS:    1 16:30:48     -467.526703        0.1320\n",
      "BFGS:    2 16:30:48     -467.534088        0.0035\n",
      "      Step     Time          Energy         fmax\n",
      "BFGS:    0 16:30:48     -518.694092        0.7792\n",
      "BFGS:    1 16:30:49     -518.722046        0.6676\n",
      "BFGS:    2 16:30:49     -518.753479        1.1360\n",
      "BFGS:    3 16:30:49     -518.766785        0.3977\n",
      "BFGS:    4 16:30:49     -518.778076        0.2571\n",
      "BFGS:    5 16:30:49     -518.780762        0.2640\n",
      "BFGS:    6 16:30:49     -518.782776        0.0935\n",
      "BFGS:    7 16:30:50     -518.783081        0.0494\n",
      "BFGS:    8 16:30:50     -518.783203        0.0442\n",
      "BFGS:    9 16:30:50     -518.783386        0.0503\n",
      "BFGS:   10 16:30:50     -518.783447        0.0373\n",
      "BFGS:   11 16:30:50     -518.783508        0.0181\n",
      "BFGS:   12 16:30:50     -518.783508        0.0108\n",
      "BFGS:   13 16:30:51     -518.783630        0.0085\n"
     ]
    }
   ],
   "source": [
    "# Build one batch per reference slab with the configured cutoff on DEVICE\n",
    "ref_slab_batches = []\n",
    "for slab in ref_slabs:\n",
    "    batch = get_atoms_batch(\n",
    "        slab,\n",
    "        system_settings[\"cutoff\"],\n",
    "        DEVICE,\n",
    "        props={\"energy\": 0, \"energy_grad\": []},\n",
    "    )\n",
    "    ref_slab_batches.append(batch)\n",
    "\n",
    "# Create a SurfaceSystem for every batch, sharing the calculator and settings\n",
    "ref_surfs = [\n",
    "    SurfaceSystem(\n",
    "        batch,\n",
    "        calc=nff_surf_calc,\n",
    "        system_settings=system_settings,\n",
    "        save_folder=run_folder,\n",
    "    )\n",
    "    for batch in ref_slab_batches\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below should output:\n",
    "```\n",
    "energy of reference slab is [35.931]\n",
    "energy of reference slab is [12.478]\n",
    "energy of reference slab is [-4.876]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "energy of reference slab is [35.931]\n",
      "energy of reference slab is [12.478]\n",
      "energy of reference slab is [-4.876]\n"
     ]
    }
   ],
   "source": [
    "# Expected surface energies, in the same order as `ref_slab_files`\n",
    "expected_surf_energies = [35.931, 12.478, -4.876]\n",
    "assert len(ref_surfs) == len(expected_surf_energies), \"one expected energy per reference slab\"\n",
    "\n",
    "for surf, expected_energy in zip(ref_surfs, expected_surf_energies):\n",
    "    surf_energy = surf.get_surface_energy()\n",
    "    print(f\"energy of reference slab is {surf_energy}\")\n",
    "    # Fail loudly on a regression instead of relying on a visual comparison.\n",
    "    # The reference values are quoted to 3 decimals, hence the absolute tolerance.\n",
    "    np.testing.assert_allclose(surf_energy, expected_energy, rtol=0, atol=1e-2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('surface_sampling': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e3e0723b7fd9866ee8f9ae4f62931968cf8456ef2195b337b8930ae9f61708cf"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
